// Copyright 2021 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package errwrap

import (
	"fmt"
	"go/ast"
	"go/constant"
	"go/types"
	"regexp"
	"strings"

	"github.com/cockroachdb/cockroach/pkg/testutils/lint/passes/passesutil"
	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
	"golang.org/x/tools/go/ast/inspector"
)

// Doc documents this pass. It is assigned to Analyzer.Doc; by go/analysis
// convention its first line serves as the one-line summary of the analyzer.
// The nolint:errwrap opt-out described at the end corresponds to the
// passesutil.HasNolintComment checks performed by run.
const Doc = `checks for unwrapped errors.

This linter checks that:

- err.Error() is not passed as an argument to an error-creating
  function.

- the '%s', '%v', and '%+v' format verbs are not used to format
  errors when creating a new error.

In both cases, an error-wrapping function can be used to correctly
preserve the chain of errors so that user-directed hints, links to
documentation issues, and telemetry data are all propagated.

It is possible for a call site to opt the format/message string
out of the linter using /* nolint:errwrap */ on or before the line
that creates the error.`

var errorType = types.Universe.Lookup("error").Type().String()

// Analyzer checks for improperly wrapped errors. It builds on the result of
// the inspect pass (which run uses to visit call expressions) and reports
// its findings from run.
var Analyzer = &analysis.Analyzer{
	Run:      run,
	Requires: []*analysis.Analyzer{inspect.Analyzer},
	Doc:      Doc,
	Name:     "errwrap",
}

// run is the analysis entry point. It visits every call expression in the
// package being analyzed and, for calls that return an error and whose
// callee is listed in ErrorFnFormatStringIndex, reports:
//   - each argument that is the result of an Error() method call (see
//     isErrorStringCall), and
//   - when the callee takes a format string, each error-typed operand that
//     the format renders with %v, %+v or %s.
//
// Calls annotated with a nolint:errwrap comment are not reported. The pass
// produces no result value and never returns an error.
func run(pass *analysis.Pass) (interface{}, error) {
	insp := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
	callsOnly := []ast.Node{(*ast.CallExpr)(nil)}

	insp.Preorder(callsOnly, func(node ast.Node) {
		// Bugs in the linter surface as panics carrying an error value; turn
		// those into a diagnostic at the offending call instead of aborting
		// the analysis. Any other panic value is re-raised.
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			err, isErr := r.(error)
			if !isErr {
				panic(r)
			}
			pass.Reportf(node.Pos(), "internal linter error: %v", err)
		}()

		// Only calls that yield exactly the built-in error type, made through
		// a selector (pkg.Fn or recv.Method) that resolves to a function
		// belonging to a package, are of interest.
		call, isCall := node.(*ast.CallExpr)
		if !isCall || pass.TypesInfo.TypeOf(call).String() != errorType {
			return
		}
		sel, isSel := call.Fun.(*ast.SelectorExpr)
		if !isSel {
			return
		}
		fn, isFunc := pass.TypesInfo.Uses[sel.Sel].(*types.Func)
		if !isFunc {
			return
		}
		pkg := fn.Pkg()
		if pkg == nil {
			return
		}

		// Files generated by go-bindata are exempt.
		if strings.HasSuffix(pass.Fset.File(node.Pos()).Name(), "/embedded.go") {
			return
		}

		formatIdx, isErrFn := ErrorFnFormatStringIndex[stripVendor(fn.FullName())]
		if !isErrFn {
			// Not a known error-creating function.
			return
		}

		// optedOut reports whether the call carries a nolint:errwrap comment.
		optedOut := func() bool {
			return passesutil.HasNolintComment(pass, sel, "errwrap")
		}
		report := func(msg string) {
			pass.Report(analysis.Diagnostic{Pos: node.Pos(), Message: msg})
		}

		// No argument may be the flattened string of an error.
		for _, arg := range call.Args {
			if !isErrorStringCall(pass, arg) || optedOut() {
				continue
			}
			report(fmt.Sprintf(
				"err.Error() is passed to %s.%s; use pgerror.Wrap/errors.Wrap/errors.CombineErrors/"+
					"errors.WithSecondaryError/errors.NewAssertionErrorWithWrappedErrf instead",
				pkg.Name(), fn.Name()))
		}

		if formatIdx < 0 {
			// The function takes no format string.
			return
		}
		verbs, haveVerbs := getFormatStringVerbs(pass, call, formatIdx)
		if !haveVerbs {
			return
		}

		// Pair each operand following the format string with the verb at the
		// same position, and flag errors that are rendered as plain text.
		for i, operand := range call.Args[formatIdx+1:] {
			if i >= len(verbs) {
				break
			}
			if pass.TypesInfo.TypeOf(operand).String() != errorType {
				continue
			}
			switch verbs[i] {
			case "%v", "%+v", "%s":
				if optedOut() {
					continue
				}
				report(fmt.Sprintf(
					"non-wrapped error is passed to %s.%s; use pgerror.Wrap/errors.Wrap/errors.CombineErrors/"+
						"errors.WithSecondaryError/errors.NewAssertionErrorWithWrappedErrf instead",
					pkg.Name(), fn.Name(),
				))
			}
		}
	})

	return nil, nil
}

// isErrorStringCall reports whether expr is a string-typed call of the form
// x.Error(), where the selected Error is a function or method with signature
// func() string, i.e. the result of an `(error).Error()` method call.
//
// Calls whose selector does not resolve to a *types.Func (for example a
// func-typed struct field or package-level variable) are rejected rather
// than causing a panic.
func isErrorStringCall(pass *analysis.Pass, expr ast.Expr) bool {
	call, ok := expr.(*ast.CallExpr)
	if !ok {
		return false
	}
	if t := pass.TypesInfo.TypeOf(call); t == nil || t.String() != "string" {
		return false
	}
	callSel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	// The selector may denote a variable or field of function type; an
	// unchecked type assertion would panic on those, and run's recover would
	// turn that into a spurious "internal linter error" diagnostic.
	fun, ok := pass.TypesInfo.Uses[callSel.Sel].(*types.Func)
	if !ok {
		return false
	}
	return fun.Name() == "Error" && fun.Type().String() == "func() string"
}

// formatVerbRegexp matches one fmt-style formatting directive: a '%',
// optional flags, an optional explicit argument index, an optional width and
// precision (each either digits or '*', possibly preceded by an argument
// index), and finally the verb rune. Submatch 1 captures everything between
// the '%' and the verb; submatch 2 captures the verb. A literal "%%" matches
// as a directive whose verb is '%'. The (?s) flag lets the verb be any rune,
// mirroring how the fmt package consumes the rune that follows the modifiers.
var formatVerbRegexp = regexp.MustCompile(
	`(?s)%([-+# 0]*(?:\[\d+\])?(?:\*|\d+)?(?:\.(?:\[\d+\])?(?:\*|\d+)?)?(?:\[\d+\])?)(.)`,
)

// getFormatStringVerbs returns the formatting directives of the format string
// argument at index formatStringIdx of call, one entry per operand the
// directives consume (see parseFormatVerbs), so that entry i corresponds to
// call.Args[formatStringIdx+1+i].
//
// The boolean result is false when the argument is missing, is not a
// compile-time constant string (its contents are then unknown), or when the
// format uses explicit argument indexes.
// Based on https://github.com/polyfloyd/go-errorlint/blob/e4f368f0ae6983eb40821ba4f88dc84ac51aef5b/errorlint/lint.go#L88
func getFormatStringVerbs(
	pass *analysis.Pass, call *ast.CallExpr, formatStringIdx int,
) ([]string, bool) {
	if formatStringIdx < 0 || len(call.Args) <= formatStringIdx {
		return nil, false
	}
	// Any constant string expression is accepted, not only a single literal,
	// so that format strings split across lines with "..." + "..." and named
	// string constants are checked too.
	tv, ok := pass.TypesInfo.Types[call.Args[formatStringIdx]]
	if !ok || tv.Value == nil || tv.Value.Kind() != constant.String {
		return nil, false
	}
	return parseFormatVerbs(constant.StringVal(tv.Value))
}

// parseFormatVerbs lists, in order, one entry per operand consumed by format:
// the full directive text (e.g. "%v", "%+v", "%-10s") for each verb, preceded
// by a "*" placeholder for every '*' width or precision, since those consume
// an operand of their own. Literal percent directives ("%%") consume nothing
// and produce no entry. The boolean result is false if the format uses
// explicit argument indexes such as "%[2]v", because operands would then not
// map to directives sequentially.
func parseFormatVerbs(format string) ([]string, bool) {
	var verbs []string
	for _, m := range formatVerbRegexp.FindAllStringSubmatch(format, -1) {
		directive, modifiers, verb := m[0], m[1], m[2]
		if verb == "%" {
			// Prints a literal '%' and does not absorb an operand.
			continue
		}
		if verb == "[" || strings.Contains(modifiers, "[") {
			// Explicit (or malformed) argument indexes reorder operands; give
			// up rather than pair the wrong operand with a verb.
			return nil, false
		}
		for i := strings.Count(modifiers, "*"); i > 0; i-- {
			verbs = append(verbs, "*")
		}
		verbs = append(verbs, directive)
	}
	return verbs, true
}

// stripVendor removes everything up to and including the first "/vendor/"
// path segment of s, so that a vendored qualified name such as
// "example.com/app/vendor/github.com/pkg/errors.New" becomes
// "github.com/pkg/errors.New". Names without a vendor segment are returned
// unchanged.
func stripVendor(s string) string {
	const vendorDir = "/vendor/"
	idx := strings.Index(s, vendorDir)
	if idx < 0 {
		return s
	}
	return s[idx+len(vendorDir):]
}
